{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from batchgenerators.dataloading.data_loader import SlimDataLoaderBase\n",
    "from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter\n",
    "\n",
    "import numpy as np\n",
    "import time\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Defining an \"Epoch Dataloader with Shuffle\"\n",
    "---\n",
    "This is a dataloader as defined generally in most deep learning framework *with*\n",
    "all the benefits of `batchgenerators`. The dataloader, using the `SlimDataLoaderBase`\n",
    "from `batchgenerators`, has the following properties:\n",
    "\n",
    "1. Each epoch covers the dataset once\n",
    "2. Each new epoch 'deterministically' shuffles the data before creating batches\n",
    "3. Using `MultiThreadedAugmentor` to use multiple processes to load and transform\n",
    "data _while_ using batch size > 1\n",
    "\n",
    "Comments have been added throughout the code to explain important parts of writing\n",
    "a dataloader with `batchgenerators`. \n",
    "\n",
    "Also highlighted is the interplay between\n",
    "`SlimDataLoaderBase` and `MultithreadedAugmentor`, the later of which is used for rapid mini-batched\n",
    "loading and augmenting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EpochDLWithShuffle(SlimDataLoaderBase):\n",
    "    def __init__(self, data, num_threads_in_mt=12, batch_size=4):\n",
    "        # This initializes self._data, self.batch_size and self.number_of_threads_in_multithreaded\n",
    "        super(EpochDLWithShuffle, self).__init__(data, batch_size, num_threads_in_mt)\n",
    "\n",
    "        self.num_restarted = 0\n",
    "        self.current_position = 0\n",
    "        self.was_initialized = False\n",
    "\n",
    "    def reset(self):\n",
    "        # Prevents the random order for each epoch being the same\n",
    "        rs = np.random.RandomState(self.num_restarted)\n",
    "        \n",
    "        # Here the data is shuffled but one can easily replace this with a \n",
    "        # shuffle of indices for when one wants to load the data while generating\n",
    "        # a batch in real-time, for example.\n",
    "        #\n",
    "        # Eg. rs.shuffle(self._data_indices) instead of line below\n",
    "        rs.shuffle(self._data)\n",
    "        self.was_initialized = True\n",
    "        self.num_restarted = self.num_restarted + 1\n",
    "\n",
    "        # Select a starting point for this subprocess. The self.thread_id is set by\n",
    "        # MultithreadedAugmentor and is in the range [0, num_of_threads_in_mt)\n",
    "        # Multiplying it with batch_size gives every subprocess a unique starting \n",
    "        # point WHILE taking into consideration the size of the batch\n",
    "        self.current_position = self.thread_id*self.batch_size\n",
    "\n",
    "    def generate_train_batch(self):\n",
    "        # This method HAS to be defined and is used to return batches.\n",
    "\n",
    "        # For doing the initialization in each subprocess generated by MultiThreadedAugmentor\n",
    "        if not self.was_initialized:\n",
    "            self.reset()\n",
    "        \n",
    "        # This will be used for the batch starting point in this loop\n",
    "        idx = self.current_position\n",
    "\n",
    "        if idx < len(self._data):\n",
    "            \n",
    "            # Next starting point. This skips the length of one batch for\n",
    "            # this process AS WELL AS all the other processes (i.e, self.number_of_threads_in_multithreaded)\n",
    "            # Since the processes already have unique (but contiguous) starting \n",
    "            # points due to the initialization of self.current_position in \n",
    "            # reset(), they continue to not overlap.\n",
    "            self.current_position = idx + self.batch_size*self.number_of_threads_in_multithreaded\n",
    "\n",
    "            # Having assured that the next starting point is safe, we simply\n",
    "            # return the next batch. Additionally, we take into consideration\n",
    "            # that the idx+batch_size might exceed the dataset size so we take \n",
    "            # min(len(self._data),idx+self.batch_size) as the end index\n",
    "            return self._data[idx: min(len(self._data),idx+self.batch_size)]\n",
    "        else:\n",
    "            self.was_initialized=False\n",
    "            raise StopIteration"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MultiThreadedAugmentor and running the Dataloader\n",
    "---\n",
    "\n",
    "The `MultiThreadedAugmentor` allows for rapid data loading and augmentation\n",
    "(`batchgenerators` has these as well) using multiple sub-processes. This is designed\n",
    "to be seamlessly used with the dataloader defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total samples in Dataset: 107\n",
      "Batch 1:[84, 10, 75, 2]\n",
      "Batch 2:[24, 99, 106, 7]\n",
      "Batch 3:[16, 86, 68, 22]\n",
      "Batch 4:[45, 60, 76, 52]\n",
      "Batch 5:[13, 73, 85, 54]\n",
      "Batch 6:[102, 8, 26, 92]\n",
      "Batch 7:[33, 3, 66, 48]\n",
      "Batch 8:[30, 6, 78, 94]\n",
      "Batch 9:[89, 93, 100, 59]\n",
      "Batch 10:[27, 18, 61, 51]\n",
      "Batch 11:[63, 71, 43, 1]\n",
      "Batch 12:[79, 42, 41, 4]\n",
      "Batch 13:[15, 17, 40, 38]\n",
      "Batch 14:[5, 53, 98, 56]\n",
      "Batch 15:[0, 34, 28, 55]\n",
      "Batch 16:[50, 11, 62, 35]\n",
      "Batch 17:[23, 31, 96, 57]\n",
      "Batch 18:[82, 91, 32, 90]\n",
      "Batch 19:[14, 74, 19, 29]\n",
      "Batch 20:[49, 104, 105, 69]\n",
      "Batch 21:[80, 20, 101, 72]\n",
      "Batch 22:[77, 25, 37, 81]\n",
      "Batch 23:[46, 97, 39, 65]\n",
      "Batch 24:[58, 12, 95, 88]\n",
      "Batch 25:[70, 87, 36, 21]\n",
      "Batch 26:[83, 9, 103, 67]\n",
      "Batch 27:[64, 47, 44]\n",
      "Total samples returned by augmentor: 107\n",
      "Coverage of samples returned by augmentor of original dataset: 107/107\n"
     ]
    }
   ],
   "source": [
    "# Some random data. Deliberately choosing odd length\n",
    "data = list(range(107))\n",
    "batch_size = 4\n",
    "\n",
    "# Number of subprocesses to be used in MultiThreadedAugmentor. It is important to\n",
    "# coordinate this variable between the dataloader you are defining \n",
    "# (which subclasses SlimDataLoaderBase, eg, we call it EpochDLWithShuffle here)\n",
    "# for calculations. Basically, pass the same value to both.\n",
    "num_threads_in_mt=12        \n",
    "\n",
    "# Just checking the samples in our original dataset\n",
    "print(f\"Total samples in Dataset: {len(data)}\")\n",
    "\n",
    "# Let's create an instance of our dataloader which we defined in the previous code cell\n",
    "dl = EpochDLWithShuffle(\n",
    "        data=data, \n",
    "        num_threads_in_mt=num_threads_in_mt,    # This should be the same as *\n",
    "        batch_size=batch_size\n",
    "        )\n",
    "\n",
    "\n",
    "# Multithreaded augmemter with no transforms. For actual datasets, feel free\n",
    "# to use multiple transforms from batchgenerators. The bare minimum you need to\n",
    "# be concerned about while using MultiThreadedAugmenter is the data_loader,\n",
    "# transforms and num of processes. The speedup is essentially due to individual\n",
    "# sub-processes performing the transforms on the data and loading them onto Queues\n",
    "# for your training loop to retrieve from when iterating over the \n",
    "# MultiThreadedAugmenter object\n",
    "mt = MultiThreadedAugmenter(\n",
    "            data_loader=dl, \n",
    "            transform=None,                      # batchgenerators provides augmentations for 2D and 3D data\n",
    "            num_processes=num_threads_in_mt      # * as this\n",
    "            )\n",
    "# For optimizing performance further, please read about the other parameters of\n",
    "# MultiThreadedAugmenter\n",
    "\n",
    "# Iterating over one epoch to demonstrate batches being retrieved\n",
    "for epoch in range(1):\n",
    "    l = []\n",
    "    for batch_id, batch in enumerate(mt):\n",
    "        print(f\"Batch {batch_id+1}:{batch}\")\n",
    "\n",
    "        # Let's store the batches to analyze them later\n",
    "        l.extend(batch)      \n",
    "\n",
    "    # How many samples were returned by the augmentor?\n",
    "    print(f\"Total samples returned by augmentor: {len(l)}\")\n",
    "\n",
    "    # How much of the original data was covered in the batches return?\n",
    "    print(f\"Coverage of samples returned by augmentor of original dataset: {len(set(l).intersection(set(range(len(data)))))}/{len(data)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Timing an epoch\n",
    "---\n",
    "The major speed benefits are seen while working with actual data and an augmentation\n",
    "pipeline (check `batchgenerators` README for an example of this). \n",
    "\n",
    "However, an example of timing code is provided and can be used to check for speedups when using real data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "This batch took 0.002 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.021 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.000 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.002 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "This batch took 0.001 s\n",
      "Multi threaded batch generation using 12 workers took 0.001 s on average per batch\n"
     ]
    }
   ],
   "source": [
    "# Timing a full epoch. However, this is without transformations and speedups are\n",
    "# usually observed while using tranformations\n",
    "batch_times = []\n",
    "for _ in range(math.ceil(len(data)/batch_size)):\n",
    "    start = time.time()\n",
    "    _ = next(mt)\n",
    "    batch_times.append(time.time() - start)\n",
    "    print(\"This batch took %02.3f s\" % batch_times[-1])\n",
    "\n",
    "avg_batch_time_mt = np.mean(batch_times)\n",
    "print(\"Multi threaded batch generation using 12 workers took %02.3f s on average per batch\" % avg_batch_time_mt)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.2 64-bit ('saikat-deeplearn': conda)",
   "name": "python392jvsc74a57bd081c1cb16bfde5ea277679adc7b8e2552bbfaaf033586247b62514ceb2f47d4c4"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.2"
  },
  "metadata": {
   "interpreter": {
    "hash": "81c1cb16bfde5ea277679adc7b8e2552bbfaaf033586247b62514ceb2f47d4c4"
   }
  },
  "orig_nbformat": 2
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
